"""

    NNSDE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
        batch = true, sub_batch = 1, strong_loss = false,
        moment_loss = false, param_estim = false, dataset = [],
        data_sub_batch = 1, numensemble = 10, additional_loss = nothing, kwargs...)

This is an algorithm for solving stochastic ordinary differential equations using a specialization of physics-informed neural networks (PINNs).
Allows users to solve standard `SDEProblem`s using a Stochastic PINN (SPINN) solver.

!!! warning

    NNSDE only supports SDEs written in out-of-place form, i.e. drift `f(u,p,t)` and diffusion
    `g(u,p,t)`, and not the in-place form `f(du,u,p,t)`. If the problem is not declared out-of-place,
    the NNSDE algorithm will exit with an error.

## Positional Arguments

* `chain`: A neural network (NN) architecture specific to SPINNs, whose input dimensions correspond
            to time and `n` independent random variables, chosen as the orthogonal random basis of the diffusion term's Wiener process in its Karhunen-Loève (KL) expansion.
            `n` must be chosen by the user, depending on how accurately they want to represent stochasticity via the eigenvalues of the SDE's KL expansion in the SPINN loss function.
            The chain is defined as a `Lux.AbstractLuxLayer` or `Flux.Chain`. A `Flux.Chain` will be converted to `Lux` using
            `adapt(FromFluxAdaptor(), chain)`.
* `opt`: The optimizer to train the neural network.
* `init_params`: The initialization scheme for the neural network. By default, this is `nothing`
                 which thus uses the random initialization provided by the neural network
                 library.

## Keyword Arguments

* `strategy`: The training strategy used to choose the points for the evaluations.
              The default of `nothing` means that `QuadratureTraining` with QuadGK is used if no
              `dt` is given, and `GridTraining` is used with `dt` if given.
              For the SDE solver, `GridTraining` is recommended for better weak solution estimates.

* `autodiff`: The switch between automatic and numerical differentiation for
              the PDE operators. The reverse mode of the loss function is always
              automatic differentiation (via Zygote), this is only for the derivative
              in the loss function (the derivative with respect to time).

* `batch`: The batch size for the loss computation. Defaults to `true`, meaning the neural
           network is applied to a row vector of `t` values simultaneously, i.e. it is the
           batch size for the neural network evaluations. This requires a neural network
           compatible with batched data. `false` means the neural network is applied at
           individual time points, one at a time. This is not applicable to
           `QuadratureTraining`, where `batch` is passed in the `strategy` and is the
           number of points at which it can evaluate the integrand in parallel.

* `sub_batch`: Its interpretation depends on the training type chosen (via the `strong_loss` argument).
               For weak-loss training, this is the number of samples of the random coefficients `z_i` taken per timepoint;
               a higher `sub_batch` almost always captures the weak solution better. A defining feature of this loss and
               training strategy is that training paths are constructed by taking separate, independent sets of `z_i` for each timepoint,
               done `n = sub_batch` times per timepoint.
               Essentially, paths are constructed from Monte Carlo ensembles of random coefficients per time step, rather than temporally consistent continuous paths.
               For strong-loss training, this is the number of solution paths trained over; the final SDEPINN solution consists of these fixed strong solution paths.
               Defaults to `1`.

* `strong_loss`: Controls the training type via the loss function's aggregation operator.
                 If `true`, the loss takes a strong form: a summation across timepoints and `n` fixed training solution paths (within each path, the same set of `z_i` coefficients at every timepoint), controlled via `sub_batch`.
                 The solution returned is a strong solution for the selected paths.
                 If `false`, the loss is weak: `n = sub_batch` training solution paths are generated by random sampling across the `z_i` probability space (within each path, independent random `z_i` values at each timepoint).
                 The solution returned is a weak solution of the SDE.
                 This allows choosing a weak or strong training discretization of the time and Gaussian coefficient domains.
                 Note that weak training is almost always faster, but if one is interested in pathwise solutions of the SDE, then `strong_loss` must be `true`.
                 Defaults to `false`.

* `moment_loss`: Allows the user to include a moment-matching loss (first and second moments, i.e. mean and variance) for the solver. It is computed against the dataset provided.
                 Defaults to `false`.

* `param_estim`: Boolean indicating whether parameters of the differential equation are
                 learnt along with the parameters of the neural network.

* `dataset`: A dataset used to train the SDEPINN using observed process samples at the respective timepoints.
             The `L2` loss and `moment_loss` are created using this. It is a Vector of `x`, `t`, where `x` is a nested vector of multiple observations of the adapted process being learnt
             (each inner vector corresponding to one strong-solution set of timeseries observations), and `t` is a vector of timepoints at which the multiple `x` observations were made.

* `data_sub_batch`:  The number of sets of random coefficients `z_i` taken per timepoint while matching moments in `moment_loss`.
                     The moment matching is done against strong process observations, so matching moments requires corresponding SDEPINN outputs and inputs.
                     Since the eigenfunction/random-coefficient decomposition cannot realistically be recovered for each observation, the final moments are approximated by taking the mean/sum
                     over multiple sets of `z_i` concatenated with timepoints as SDEPINN inputs, yielding outputs that can then be matched with the dataset mean and variance.
                     Internally set to `max(data_sub_batch, length(dataset[1]))` by default.

* `numensemble`: The solver returns an ensemble result/weak solution over the user's provided `saveat` discretization. `numensemble` controls the number
                 of solution predictions taken into the ensemble at each timepoint. Defaults to `10`.

* `additional_loss`: A function `additional_loss(phi, θ)` where `phi` is the neural network
                     trial solution and `θ` are the weights of the neural network(s).

* `kwargs`: Extra keyword arguments are splatted to the Optimization.jl `solve` call.

## Examples

```julia
u0 = [1.0, 1.0]
ts = [t for t in 1:100]
(u_, t_) = (analytical_func(ts), ts)
function additional_loss(phi, θ)
    return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
end
alg = NNSDE(chain, opt, additional_loss = additional_loss)
```

```julia
u₀ = 0.5
f(u, p, t) = 1.2 * u
g(u, p, t) = 1.1 * u
tspan = (0.0, 1.0)
prob = SDEProblem(f, g, u₀, tspan)
n_z = 3
dim = 1 + n_z
luxchain = Chain(Dense(dim, 16, σ), Dense(16, 16, σ), Dense(16, 1))
opt = BFGS()
sol = solve(prob, NNSDE(luxchain, opt), verbose = true, dt = 1 / 50.0f0, abstol = 1e-10, maxiters = 200)
```
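
Training granularity can be adjusted through the sampling keywords. A minimal sketch reusing `prob`, `luxchain`, and `opt` from the example above (the settings shown are illustrative, not tuned):

```julia
# weak training (default): 20 independent z_i samples per timepoint,
# with an ensemble of 20 network predictions at each saved timepoint
alg_weak = NNSDE(luxchain, opt; sub_batch = 20, numensemble = 20)
sol_weak = solve(prob, alg_weak, verbose = true, dt = 1 / 50.0f0, maxiters = 200)

# strong training: the loss is instead summed over 20 fixed training paths
alg_strong = NNSDE(luxchain, opt; strong_loss = true, sub_batch = 20)
```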

## Solution Notes

Note that the returned weak solution is evaluated at fixed time points according to standard output
handlers such as `saveat` and `dt`. However, the neural network is a fully continuous
solution, so `sol(t)` is an accurate interpolation (up to the neural network training
result). In addition, the `OptimizationSolution` is returned as `sol.k` for further
analysis.
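
For instance (a sketch, assuming `sol` was returned by the `solve` call in the examples above):

```julia
u_mid = sol(0.5)   # continuous interpolation of the trained network at t = 0.5
opt_res = sol.k    # the underlying OptimizationSolution, for further analysis
```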

## References

* Stochastic Physics-Informed Neural Ordinary Differential Equations: https://arxiv.org/abs/2109.01621
* Stochastic PDE Functionality #531: https://github.com/SciML/NeuralPDE.jl/issues/531

"""

@concrete struct NNSDE
    chain <: AbstractLuxLayer
    opt
    init_params
    strategy <: Union{Nothing, AbstractTrainingStrategy}
    autodiff::Bool
    batch::Bool
    sub_batch::Int64
    strong_loss::Bool
    moment_loss::Bool
    param_estim::Bool
    dataset <: Union{Vector, Vector{<:Vector}}
    data_sub_batch::Int64
    numensemble::Number
    additional_loss <: Union{Nothing, Function}
    kwargs
end

function NNSDE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
        batch = true, sub_batch = 1, strong_loss = false,
        moment_loss = false, param_estim = false, dataset = [],
        data_sub_batch = 1, numensemble = 10, additional_loss = nothing, kwargs...)
    chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))
    return NNSDE(
        chain, opt, init_params, strategy, autodiff, batch, sub_batch, strong_loss, moment_loss,
        param_estim, dataset, data_sub_batch, numensemble, additional_loss, kwargs)
end

"""
    SDEPhi(chain::Lux.AbstractLuxLayer, t, u0, st)

Internal struct used for representing the SDE solution as a neural network in a form that
respects the initial condition, i.e. `phi(inp) = u0 + (inp[1] - t0) * NN(inp)`.
"""
@concrete struct SDEPhi
    u0
    t0
    smodel <: StatefulLuxLayer
end

function SDEPhi(model::AbstractLuxLayer, t0::Number, u0, st)
    return SDEPhi(u0, t0, StatefulLuxLayer{true}(model, nothing, st))
end

function (f::SDEPhi)(inp, θ)
    dev = safe_get_device(θ)
    return f(dev, safe_expand(dev, inp), θ)
end

# single timepoint, single sample (t[i],z_1,z_2...) input.
function (f::SDEPhi)(dev, inp::Vector{<:Number}, θ)
    res = only(cdev(f.smodel(dev(inp), θ.depvar)))
    return f.u0 .+ (inp[1] - f.t0) .* res
end

# single timepoint, multiple Matrix samples input.
function (f::SDEPhi)(dev, inp::Matrix{<:Number}, θ)
    return dev(f.u0) .+ ((inp[1, :] .- f.t0)' .* f.smodel(dev(inp), θ.depvar))
end

function generate_phi(chain::AbstractLuxLayer, t, u0, ::Nothing)
    θ, st = LuxCore.setup(Random.default_rng(), chain)
    return SDEPhi(chain, t, u0, st), θ
end

function generate_phi(chain::AbstractLuxLayer, t, u0, init_params)
    st = LuxCore.initialstates(Random.default_rng(), chain)
    return SDEPhi(chain, t, u0, st), init_params
end

"""
    ∂u_∂t(inputs, (phi, θ, autodiff))

Computes `u'`, the time derivative of the SDE solution `u`, using either forward-mode automatic differentiation or numerical differentiation.
It is the partial derivative (Jacobian) of the SDEPINN output with respect to the first input, time.
"""
function ∂u_∂t end

# ∂u_∂t is defined for single-timepoint Matrix inputs (see inner_sde_loss for the Matrix form details).
function ∂u_∂t(
        inputs::Matrix{<:Number}, operator_info::Tuple{
            SDEPhi, ComponentArrays.ComponentVector, Bool})
    phi, θ, autodiff = operator_info
    autodiff &&
        return ForwardDiff.jacobian(t -> phi(vcat(t, inputs[2:end, :]), θ), inputs[1, :])

    ϵ = sqrt(eps(eltype(inputs)))
    return (phi(vcat(inputs[1, :]' .+ ϵ, inputs[2:end, :]), θ) .- phi(inputs, θ)) ./ ϵ
end

"""
    inner_sde_loss(phi, f, g, autodiff, inputs, θ, p, param_estim, train_type)

Simple L2 inner loss for the SDE at a time `t` and random variables `z_i`, with parameters `θ` of the neural network.

For the non-batching case, a single timepoint's input is taken per call:
    - `inputs` is an `NN_input_dims x n_samples` matrix.

For the batching case, a Vector of timepoints is taken per call:
    - `inputs` is an `n_timepoints`-sized Vector of `NN_input_dims x n_samples` matrices.

`train_type` allows solving across a few strong paths of the SDE solution OR capturing the whole expected SDE weak solution.
Note: NNODE and NNSDE take only a single neural network, which may be multi-output or single-output.
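
Concretely, for each input column `[t; z₁, …, z_n]`, the pointwise quantity minimized is the squared collocation residual of the KL-expanded noise (as constructed in the method bodies below):

    |f(u, p, t) + g(u, p, t) · √2 · Σⱼ zⱼ · cos((j - 1/2) · π · t) - ∂u/∂t|²

with `train_type` (`sum` or `mean`) aggregating these squared residuals over the samples axis.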
"""
function inner_sde_loss end

# no batching across time
function inner_sde_loss(
        phi::SDEPhi, f, g, autodiff::Bool, inputs::P,
        θ, p, param_estim::Bool, train_type) where {P <: Matrix{<:Number}}
    p_ = param_estim ? θ.p : p

    # u is NN_output_dims x n_samples matrix.
    u = phi(inputs, θ)
    n_inp = phi.smodel.model[1].in_dims - 1
    n_samples = size(inputs)[2]

    # inputs is a NN_input_dims x n_samples matrix, first row being the time domain discretized point.
    # each timepoint in domain has an associated (1+n_z) x n_samples Matrix of NN inputs.
    fs = if phi.u0 isa Number
        reduce(hcat,
            [f(u[:, i][1], p_, inputs[1, i]) +
             g(u[:, i][1], p_, inputs[1, i]) * √2 *
             sum(inputs[1 + j, i] * cos((j - 1 / 2)pi * inputs[1, i]) for j in 1:n_inp)
             for i in 1:n_samples])
    else
        # multioutput case.
        reduce(hcat,
            [f(u[:, i], p_, inputs[1, i]) +
             g(u[:, i], p_, inputs[1, i]) * √2 *
             sum(inputs[1 + j, i] * cos((j - 1 / 2)pi * inputs[1, i]) for j in 1:n_inp)
             for i in 1:n_samples])
    end

    # dudt is jacobian matrix NN_output_dims x n_samples.
    dudt = ∂u_∂t(inputs, (phi, θ, autodiff))

    # sum call is over multioutputs, train_type called over dims=2/n_samples axis.
    return sum(train_type(abs2.(fs .- dudt), dims = 2))
end

# batching case
function inner_sde_loss(
        phi::SDEPhi, f, g, autodiff::Bool, inputs::P,
        θ, p, param_estim::Bool, train_type) where {P <:
                                                    Vector{<:Matrix{<:Number}}}
    p_ = param_estim ? θ.p : p
    n_inp = phi.smodel.model[1].in_dims - 1
    # quadrature alg call case handling
    n_samples = isempty(inputs) ? 0 : size(inputs[1])[2]

    # same dims as inputs, u[i] now being a NN_output_dims x n_samples Matrix
    u = Base.Fix2(phi, θ).(inputs)

    # inputs is a Vector of NN_input_dims x n_samples matrices, each timepoint[i] -> inputs[i] Matrix{<:Number}.
    # each timepoint in domain has an associated (1+n_z) x n_samples Matrix of NN inputs. 
    fs = if phi.u0 isa Number
        [reduce(hcat,
             [f(u[k][:, i][1], p_, inputs[k][1, i]) +
              g(u[k][:, i][1], p_, inputs[k][1, i]) * √2 *
              sum(inputs[k][1 + j, i] * cos((j - 1 / 2)pi * inputs[k][1, i])
              for j in 1:n_inp) for i in 1:n_samples]) for k in eachindex(inputs)]
    else
        # multioutput case
        [reduce(hcat,
             [f(u[k][:, i], p_, inputs[k][1, i]) +
              g(u[k][:, i], p_, inputs[k][1, i]) * √2 *
              sum(inputs[k][1 + j, i] * cos((j - 1 / 2)pi * inputs[k][1, i])
              for j in 1:n_inp) for i in 1:n_samples]) for k in eachindex(inputs)]
    end

    # same dims as inputs, dudt[i] now being a NN_output_dims x n_samples Matrix
    dudt = Base.Fix2(∂u_∂t, (phi, θ, autodiff)).(inputs)

    return sum(sum(train_type(abs2.(fs[i] .- dudt[i]), dims = 2))
    for i in eachindex(inputs)) / length(inputs)
end

"""
    add_rand_coeff(times, n_z, sub_batch)

Samples, for `n = sub_batch` independent random coefficient/basis variables `z_i`, values at each timepoint.
This is a weak training discretization of the time domain, similar to Monte Carlo sampling across the `z_i` probability spaces.
`n_z` is the number of independent random basis vectors for the Brownian motion's probability space, used for truncating via the KL expansion.
Returns a list appending `n = n_z` sampled standard Gaussian values to a fixed time's value, or to a list of times.
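
For example, a single timepoint with `n_z = 2` and `sub_batch = 3` produces a `3 × 3` matrix, one `[t; z₁; z₂]` column per sample (a sketch; the `z_i` are fresh standard-normal draws on each call):

```julia
add_rand_coeff(0.5, 2, 3)   # 3×3 Matrix: first row is all 0.5, rows 2-3 are z samples
```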
"""
# strategy 1 -> train for the expected behaviour of SDE solution.
function add_rand_coeff(times::P, n_z::Int64, sub_batch::Int64) where {P <: Number}
    return reduce(hcat, [vcat(times, rand(Normal(0, 1), n_z)) for i in 1:sub_batch])
end

function add_rand_coeff(times::Vector, n_z::Int64, sub_batch::Int64)
    return [reduce(hcat, [vcat(time, rand(Normal(0, 1), n_z)) for i in 1:sub_batch])
            for time in times]
end

"""
    add_rand_coeff_2(times, n_z, num_samples)

Samples inputs for `n = num_samples` strong paths (the `z_i` are the same at every timepoint within a sample) - a strong training discretization over the input domain.
`n_z` is the number of independent random basis vectors for the Brownian motion's probability space, used for truncating via the KL expansion.
Returns a list appending `n = n_z` sampled standard Gaussian random variable values to a list of times.
"""
# strategy 2 -> train over n = num_samples strong realisations of the process.
function add_rand_coeff_2(times, n_z::Int64, num_samples)
    # each timepoint is paired with a set of fixed n_z random coefficients.
    # This is defined via a filtration on the estimated probability space for the adapted process.
    # t is fixed, so the eigenfunctions/eigenvalues are fixed with it. W (and therefore z_i) is not fixed, so the z_i are sampled randomly.

    zi_samples = [rand(Normal(0, 1), n_z) for i in 1:num_samples]
    return [reduce(hcat, [vcat(time, zi_samples[i]) for i in 1:num_samples])
            for time in times]
end

"""
    generate_DataMoments_loss(dataset, n_z, phi, f, g, autodiff, p, param_estim, data_sub_batch, train_type)

Returns a function that computes an `L2` loss between the neural network's output and the provided dataset.
This is a naive moment-matching loss for the mean and variance of the SDEPINN vs. the dataset (works best with low `sub_batch` values plus the strong loss).
It assumes that directly matching the first and second moments captures the solution, and that the behaviour of the `z_i` can be averaged sufficiently over `n = data_sub_batch` inputs.
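
Schematically, the mean-matching term computed below is (a sketch; `X_t` denotes the dataset observations at time `t`, and the inner mean is over the sampled `z` inputs):

    (1 / T) · Σₜ | mean(X_t) - mean_z(phi([t; z], θ)) |²

together with an analogous sample-variance matching term and the physics residual evaluated on the dataset inputs.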
"""
function generate_DataMoments_loss(
        dataset::Vector{<:Vector}, n_z::Int64, phi::SDEPhi, f, g,
        autodiff::Bool, p, param_estim::Bool, data_sub_batch::Int64, train_type)
    # n_timepoints x data_sub_batch Matrix
    process = reduce(hcat, dataset[1])
    # n_timepoints sized Vector
    ts = dataset[2]

    # construct NN inputs of form [t,n_i] for the physics loss to be applied on the dataset points.
    sdephi_inputs = train_type == sum ? add_rand_coeff_2(ts, n_z, data_sub_batch) :
                    add_rand_coeff(ts, n_z, data_sub_batch)

    # moment matching (MSE across time for 1st, 2nd moments) - assumes diffusion is a Gaussian at each timepoint
    # uses sample variance
    return (θ,
        _) -> begin
        sum(abs2,
            mean(process, dims = 2) .-
            mean.(Base.Fix2(phi, θ).(sdephi_inputs))) / length(ts) +
        # train_type is strong in inner loss as we have data_sub_batch strong solutions.
        abs2(
            inner_sde_loss(
            phi, f, g, autodiff, sdephi_inputs, θ, p, param_estim, train_type)
        ) +
        # get variances across cols - realisations (data_sub_batches)
        # then sum over times, to get each (timepoint, solution) -> variance
        sum(abs2,
            sum(abs2.(process .- mean(process, dims = 2)), dims = 2) .-
            sum(
                abs2.(reduce(vcat, Base.Fix2(phi, θ).(sdephi_inputs)) .-
                      mean.(Base.Fix2(phi, θ).(sdephi_inputs))),
                dims = 2)) /
        (length(ts) * (data_sub_batch - 1)^2)
    end,
    sdephi_inputs
end

"""
    generate_EM_L2loss(dataset, phi, n_output)

Returns a loss function using the provided dataset that work using collocation and moment matching.
The Moment matching loss is based on the Euler Maruyama discretization scheme for SDE solution increments.
The observed solution increments `Xt+1 - Xt` follow the Normal(`f(u,p,t) * Δt`, `g(u,p,t)^2 * Δt`) Distribution as `Δt -> 0`.
(The above follows as Wiener increments are independent & Normally Distributed with variance = `Δt`).
This loss assumes Gaussian increments for the observed SDE solution process and therefore accuracy increases at the time increments are smaller.
"""
function get_increments(x::Vector{T}) where {T}
    @views return x[2:end] .- x[1:(end - 1)]
end

function generate_EM_L2loss(dataset::Vector{<:Vector}, f, g)
    # n_timepoints-1 sized Vector
    process = reduce(hcat, dataset[1])
    Δt = get_increments(dataset[2])
    # n_timepoints-1 x n_x_observations Matrix
    X_increments = reduce(hcat, get_increments.(dataset[1]))
    n, n_samples = size(X_increments)

    loss_fn = (θ,
        _) -> begin
        gx = reduce(hcat,
            [[g(process[i, j], θ.p, dataset[2][i])^2 * Δt[i] for i in 1:n]
             for j in 1:n_samples])
        fx = reduce(hcat,
            [[f(process[i, j], θ.p, dataset[2][i]) * Δt[i] for i in 1:n]
             for j in 1:n_samples])

        # loss based on moments of the Gaussian Increments
        return sum(abs2, X_increments .- fx) + sum(abs2, abs2.(X_increments .- fx) .- gx)
    end

    return loss_fn, nothing
end

"""
    generate_loss(strategy, phi, f, g, autodiff, tspan, n_z, sub_batch, train_type, p, batch, param_estim)

Representation of the loss function, parametric on the training strategy `strategy`.
"""
function generate_loss(
        strategy::QuadratureTraining, phi, f, g, autodiff::Bool, tspan, n_z::Int64, sub_batch::Int64, train_type, p,
        batch::Bool, param_estim::Bool)
    inputs = AbstractVector{Any}[]
    zt_samples = [rand(Normal(0, 1), n_z) for i in 1:sub_batch]

    function integrand(t::Number, θ)
        inputs = train_type == sum ?
                 reduce(hcat, [vcat(t, zt_samples[i]) for i in 1:sub_batch]) :
                 add_rand_coeff(t, n_z, sub_batch)
        return abs2(inner_sde_loss(
            phi, f, g, autodiff, inputs, θ, p, param_estim, train_type))
    end

    # when ts is a 1D Array
    function integrand(ts::Vector, θ)
        inputs = train_type == sum ? add_rand_coeff_2(ts, n_z, sub_batch) :
                 add_rand_coeff(ts, n_z, sub_batch)
        return [abs2(inner_sde_loss(
                    phi, f, g, autodiff, input, θ, p, param_estim, train_type))
                for input in inputs]
    end

    function loss(θ, _)
        intf = BatchIntegralFunction(integrand, max_batch = strategy.batch)
        intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ)
        sol = solve(intprob, strategy.quadrature_alg; strategy.abstol,
            strategy.reltol, strategy.maxiters)
        return sol.u
    end

    return loss, inputs
end

function generate_loss(
        strategy::GridTraining, phi, f, g, autodiff::Bool, tspan, n_z::Int64, sub_batch::Int64,
        train_type, p, batch::Bool, param_estim::Bool)
    ts = collect(tspan[1]:(strategy.dx):tspan[2])

    # n_timepoints * (1+n_z) * n_samples -> Vector{Matrix{Float64}}
    inputs = train_type == sum ? add_rand_coeff_2(ts, n_z, sub_batch) :
             add_rand_coeff(ts, n_z, sub_batch)

    autodiff && throw(ArgumentError("autodiff not supported for GridTraining."))
    batch &&
        return (θ,
            _) -> inner_sde_loss(
            phi, f, g, autodiff, inputs, θ, p, param_estim, train_type),
        inputs
    return (θ,
        _) -> sum([inner_sde_loss(phi, f, g, autodiff, input, θ, p,
                       param_estim, train_type)
                   for input in inputs]),
    inputs
end

function generate_loss(strategy::StochasticTraining, phi, f, g, autodiff::Bool,
        tspan, n_z::Int64, sub_batch::Int64, train_type, p, batch::Bool, param_estim::Bool)
    autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining."))
    inputs = AbstractVector{Any}[]

    return (θ,
        _) -> begin
        T = promote_type(eltype(tspan[1]), eltype(tspan[2]))
        ts = ((tspan[2] - tspan[1]) .* rand(T, strategy.points) .+ tspan[1])
        inputs = train_type == sum ? add_rand_coeff_2(ts, n_z, sub_batch) :
                 add_rand_coeff(ts, n_z, sub_batch)

        if batch
            inner_sde_loss(
                phi, f, g, autodiff, inputs, θ, p, param_estim, train_type)
        else
            sum([inner_sde_loss(phi, f, g, autodiff, input, θ, p,
                     param_estim, train_type)
                 for input in inputs])
        end
    end,
    inputs
end

function generate_loss(
        strategy::WeightedIntervalTraining, phi, f, g, autodiff::Bool, tspan, n_z::Int64, sub_batch::Int64, train_type, p,
        batch::Bool, param_estim::Bool)
    autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining."))
    minT, maxT = tspan
    weights = strategy.weights ./ sum(strategy.weights)
    N = length(weights)
    difference = (maxT - minT) / N

    ts = eltype(difference)[]
    for (index, item) in enumerate(weights)
        temp_data = rand(1, trunc(Int, strategy.points * item)) .* difference .+ minT .+
                    ((index - 1) * difference)
        append!(ts, temp_data)
    end
    inputs = train_type == sum ? add_rand_coeff_2(ts, n_z, sub_batch) :
             add_rand_coeff(ts, n_z, sub_batch)

    batch &&
        return (θ,
            _) -> inner_sde_loss(
            phi, f, g, autodiff, inputs, θ, p, param_estim, train_type),
        inputs
    return (θ,
        _) -> sum([inner_sde_loss(phi, f, g, autodiff, input, θ, p,
                       param_estim, train_type)
                   for input in inputs]),
    inputs
end

function evaluate_tstops_loss(
        phi, f, g, autodiff::Bool, tstops, n_z::Int64, sub_batch::Int64,
        train_type, p, batch::Bool, param_estim::Bool)
    inputs = train_type == sum ? add_rand_coeff_2(tstops, n_z, sub_batch) :
             add_rand_coeff(tstops, n_z, sub_batch)

    batch &&
        return (θ,
            _) -> inner_sde_loss(
            phi, f, g, autodiff, inputs, θ, p, param_estim, train_type),
        inputs
    return (θ,
        _) -> sum([inner_sde_loss(phi, f, g, autodiff, input, θ, p,
                       param_estim, train_type)
                   for input in inputs]),
    inputs
end

function generate_loss(::QuasiRandomTraining, phi, f, g, autodiff::Bool,
        tspan, n_z::Int64, sub_batch::Int64, train_type, p, batch::Bool, param_estim::Bool)
    error("QuasiRandomTraining is not supported by NNODE since it's for high dimensional \
           spaces only. Use StochasticTraining instead.")
end

@concrete struct NNSDEInterpolation
    phi <: SDEPhi
    θ
end

(f::NNSDEInterpolation)(inp, ::Nothing, ::Type{Val{0}}, p, continuity) = f.phi(inp, f.θ)
(f::NNSDEInterpolation)(inp, idxs, ::Type{Val{0}}, p, continuity) = f.phi(inp, f.θ)[idxs]

function (f::NNSDEInterpolation)(
        inp::Array{<:Number, 1}, ::Nothing, ::Type{Val{0}}, p, continuity)
    out = f.phi(inp, f.θ)
    return DiffEqArray([out[:, i] for i in axes(out, 2)], inp)
end

function (f::NNSDEInterpolation)(
        inp::Array{<:Number, 1}, idxs, ::Type{Val{0}}, p, continuity)
    out = f.phi(inp, f.θ)
    return DiffEqArray([out[idxs, i] for i in axes(out, 2)], inp)
end

SciMLBase.interp_summary(::NNSDEInterpolation) = "Trained neural network interpolation"
SciMLBase.allowscomplex(::NNSDE) = true

"""
    SDEsol

Container for solutions, parameter estimates, and training/validation data from an SDE-PINN or related inverse problem solve.

# Fields

* `original` : The `OptimizationSolution` object obtained from solving the `OptimizationFunction`.

* `rode_solution` : `RODESolution` object containing the solve interpolation, the interpolating SDEPINN objects, and other related data and metadata.

* `estimated_sol` : Probabilistic estimate of the SDE weak solution using `MonteCarloMeasurements.Particles` from the ensemble fits.
                    It is a nested vector of solutions for each output, created using an ensemble solution from the set of optimized parameters.

* `timepoints` : A vector of discrete time points over which the solutions are evaluated.

* `estimated_params` : A vector of estimated SDE parameters (obtained via the data-driven fit). `nothing` if not performing parameter estimation.

* `ensemble_fits` : A vector of neural network Matrix outputs used to create the ensemble/weak solution.

* `ensemble_inputs` : A vector of Matrix inputs fed to the network while creating the ensemble/weak solution during validation.

* `numensemble` : Number of sample solutions used while creating the ensemble solution, and therefore the weak solution.

* `training_sets` : A vector of Matrix inputs used during training of the SDEPINN
                    (different for strong vs. weak training losses, as described under the `strong_loss` argument of `NNSDE`).

* `dataset_training_sets` : A vector of Matrix inputs used during training of `sdephi` when `moment_loss` is used.
                            The inputs are created from the dataset and used for inverse solves. `nothing` by default.

# Notes

* `estimated_sol` and `estimated_params` are the probabilistic estimates of the neural network solution and the differential equation parameters, respectively.
"""

@concrete struct SDEsol
    original
    rode_solution
    estimated_sol::Vector{<:Vector{<:Particles}}
    timepoints::Vector{<:Number}
    estimated_params::Union{Nothing, Vector{<:Number}}
    ensemble_fits::Vector{<:Matrix{<:Number}}
    ensemble_inputs::Vector
    numensemble::Int64
    training_sets::Vector
    dataset_training_sets::Union{Nothing, Vector}
end

function SciMLBase.__solve(
        prob::SciMLBase.AbstractSDEProblem,
        alg::NNSDE,
        args...;
        dt = nothing,
        timeseries_errors = true,
        save_everystep = true,
        adaptive = false,
        abstol = 1.0f-6,
        reltol = 1.0f-3,
        verbose = false,
        saveat = nothing,
        maxiters = nothing,
        tstops = nothing
)
    (; u0, tspan, f, g, p) = prob
    # rescaling the tspan discretization so the KL expansion can be applied in the loss formulation
    tspan_scale = tspan ./ tspan[end]
    if dt !== nothing
        dt = dt / abs(tspan_scale[2] - tspan_scale[1])
    end
    t0 = tspan_scale[1]

    # sub_batch is the number of samples (n_samples) of the truncated KL expansion's random-variable basis.
    # For weak training: a higher sub_batch yields a narrower confidence band / increased certainty in the weak solution.
    # For strong training: it means more strong paths to train over.
    # Weak loss -> weak training is the default solve mode.
    (; param_estim, sub_batch, strong_loss, moment_loss,
        chain, opt, autodiff, init_params, batch,
        additional_loss, dataset, numensemble, data_sub_batch) = alg
    n_z = chain[1].in_dims - 1
    sde_phi, init_params = generate_phi(chain, t0, u0, init_params)

    (recursive_eltype(init_params) <: Complex && alg.strategy isa QuadratureTraining) &&
        error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")

    init_params = if alg.param_estim
        ComponentArray(; depvar = init_params, p)
    else
        ComponentArray(; depvar = init_params)
    end

    @assert !isinplace(prob) "The NNSDE solver only supports out-of-place SDE definitions, i.e. du=f(u,p,t) + g(u,p,t)*dW(t)"

    strategy = if alg.strategy === nothing
        if dt !== nothing
            GridTraining(dt)
        else
            QuadratureTraining(; quadrature_alg = QuadGKJL(),
                reltol = convert(eltype(u0), reltol), abstol = convert(eltype(u0), abstol),
                maxiters, batch = 0)
        end
    else
        alg.strategy
    end

    # train_type is weak (expectation based loss + random sets of z_i for all timepoints) by default
    # use strong_loss = true for strong loss (pathwise total loss summation + same z_i for all timepoints)
    train_type = strong_loss ? sum : mean
    inner_f,
    training_sets = generate_loss(
        strategy, sde_phi, f, g, autodiff, tspan_scale, n_z,
        sub_batch, train_type, p, batch, param_estim)

    if isempty(dataset) && param_estim && isnothing(additional_loss)
        error("A dataset or an additional loss is required for inverse problems performing parameter estimation.")
    end

    # allow losses that use dataset to be used in non parameter estimation cases.
    if !isempty(dataset)
        if (length(dataset) < 2 || !(dataset isa Vector{<:Vector}))
            error("Invalid dataset. The dataset would be a timeseries (x̂,t) where x̂ is of type: Vector{<:Vector{<:AbstractFloat}} and t is type: Vector{AbstractFloat}.")
        end

        EM_L2loss, dataset_training_sets = generate_EM_L2loss(dataset, f, g)

        if moment_loss
            # the minimum sub-batch for the L2 moment loss is the number of dataset samples
            data_sub_batch = max(data_sub_batch, length(dataset[1]))
            DataMoments_loss,
            dataset_training_sets = generate_DataMoments_loss(
                dataset, n_z, sde_phi, f, g,
                autodiff, p, param_estim, data_sub_batch, train_type)
        end
    else
        dataset_training_sets = nothing
    end

    # Total loss: physics loss plus the optional additional, data-fitting, and tstops losses
    function total_loss(θ, _)
        phys_loss = inner_f(θ, sde_phi)
        if additional_loss !== nothing
            phys_loss = phys_loss + additional_loss(sde_phi, θ)
        end
        # `dataset` defaults to `[]`, so check emptiness (EM_L2loss is only defined when a dataset was given)
        if param_estim && !isempty(dataset)
            phys_loss = phys_loss + EM_L2loss(θ, sde_phi)
        end
        if param_estim && moment_loss && !isempty(dataset)
            phys_loss = phys_loss + DataMoments_loss(θ, sde_phi)
        end
        if tstops !== nothing
            num_tstops_points = length(tstops)
            tstops_loss_func = evaluate_tstops_loss(
                sde_phi, f, g, autodiff, tstops, n_z, sub_batch,
                train_type, p, batch, param_estim)
            tstops_loss = tstops_loss_func(θ, sde_phi)
            if strategy isa GridTraining
                num_original_points = length(tspan_scale[1]:(strategy.dx):tspan_scale[2])
            elseif strategy isa Union{WeightedIntervalTraining, StochasticTraining}
                num_original_points = strategy.points
            else
                return phys_loss + tstops_loss
            end
            total_original_loss = phys_loss * num_original_points
            total_tstops_loss = tstops_loss * num_tstops_points
            total_points = num_original_points + num_tstops_points
            phys_loss = (total_original_loss + total_tstops_loss) / total_points
            return phys_loss
        end
        return phys_loss
    end

    opt_algo = strategy isa QuadratureTraining ? AutoForwardDiff() : AutoZygote()
    optf = OptimizationFunction(total_loss, opt_algo)

    plen = maxiters === nothing ? 6 : ndigits(maxiters)
    callback = function (p, l)
        if verbose
            if maxiters === nothing
                @printf("[NNSDE]\tIter: [%*d]\tLoss: %g\n", plen, p.iter, l)
            else
                @printf("[NNSDE]\tIter: [%*d/%d]\tLoss: %g\n", plen, p.iter, maxiters, l)
            end
        end
        return l < abstol
    end

    optprob = OptimizationProblem(optf, init_params)
    res = solve(optprob, opt; callback, maxiters, alg.kwargs...)

    # solutions at the requested timepoints
    if saveat isa Number
        ts = tspan_scale[1]:saveat:tspan_scale[2]
    elseif saveat isa AbstractArray
        ts = saveat
    elseif dt !== nothing
        ts = tspan_scale[1]:dt:tspan_scale[2]
    elseif save_everystep
        ts = range(tspan_scale[1], tspan_scale[2], length = 100)
    else
        ts = [tspan_scale[1], tspan_scale[2]]
    end
    ts = collect(ts)

    # create a validation ensemble over all timepoints, reflecting the learnt expectation dynamics of the SDE solution.
    validation_inputs = add_rand_coeff(ts, n_z, numensemble)
    u = [sde_phi(input, res.u) for input in validation_inputs]
    n_output = chain[end].out_dims
    sol_parts = [[Particles(u[i][j, :]) for i in eachindex(ts)] for j in 1:n_output]

    estimated_sde_parameters = param_estim ? collect(res.u.p) : nothing

    sol = SciMLBase.build_solution(prob, alg, ts, sol_parts; dense = true,
        interp = NNSDEInterpolation(sde_phi, res.u), calculate_error = false,
        retcode = ReturnCode.Success, original = res, resid = res.objective)

    SciMLBase.has_analytic(prob.f) &&
        SciMLBase.calculate_solution_errors!(
            sol; timeseries_errors = true, dense_errors = false)

    # individual solution realisations and their inputs can be accessed via the ensembles and ensemble_inputs fields
    return SDEsol(
        res, sol, sol_parts, ts, estimated_sde_parameters, u, validation_inputs,
        numensemble, training_sets, dataset_training_sets)
end
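# Example usage (a minimal sketch; `f`, `g`, `prob`, the layer sizes, and the optimizer
# settings below are illustrative assumptions, not fixed by this solver):
#
#     using Lux, OptimizationOptimisers
#
#     f(u, p, t) = p[1] * u                     # drift (out-of-place form, as required)
#     g(u, p, t) = p[2] * u                     # diffusion
#     prob = SDEProblem(f, g, 1.0, (0.0, 1.0), [1.5, 0.1])
#
#     n_z = 2                                   # random variables in the KKL expansion
#     chain = Chain(Dense(n_z + 1 => 16, tanh), Dense(16 => 16, tanh), Dense(16 => 1))
#
#     alg = NNSDE(chain, Adam(0.01); numensemble = 20)
#     sol = solve(prob, alg; dt = 1 / 50, maxiters = 2000)
#
# With `dt` given, GridTraining is used, which is the recommended strategy for weak
# solution estimates (see the docstring above).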
